import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch
import pickle
from stanfordcorenlp import StanfordCoreNLP

# One shared CoreNLP client, used by translate() for English tokenization.
# NOTE(review): hard-coded Windows path to a CoreNLP 2016-10-31 unpack — this
# must exist locally or the import-time call fails before anything else runs.
nlp = StanfordCoreNLP('E:/stanford-corenlp-full-2016-10-31', lang='en')
class MyDataset(Dataset):
    """Parallel English/Chinese corpus that yields index sequences.

    Each item is a pair of lists of vocabulary indices, one per language,
    obtained by looking every token up in the corresponding word->index map.
    """

    def __init__(self, en_data, zh_data, en_w2i, zh_w2i):
        # Tokenized sentences plus the word->index vocabulary for each side.
        self.en_data = en_data
        self.zh_data = zh_data
        self.en_w2i = en_w2i
        self.zh_w2i = zh_w2i

    def __getitem__(self, idx):
        """Return (english_ids, chinese_ids) for the idx-th sentence pair."""
        en_ids = [self.en_w2i[tok] for tok in self.en_data[idx]]
        zh_ids = [self.zh_w2i[tok] for tok in self.zh_data[idx]]
        return en_ids, zh_ids

    def __len__(self):
        # The English and Chinese lists are parallel, so either length works.
        return len(self.zh_data)




class Encoder(nn.Module):
    """Embed English token ids and encode them with a single-layer LSTM.

    nn.Embedding acts as a lookup table of shape
    (encoder_corpus_num, encoder_embedding_num): each row stores the vector
    for one vocabulary index, so any id below encoder_corpus_num maps to a row.
    """

    def __init__(self, encoder_embedding_num, encoder_hidden_num, encoder_corpus_num):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(encoder_corpus_num, encoder_embedding_num)
        self.lstm = nn.LSTM(encoder_embedding_num, encoder_hidden_num, batch_first=True)

    def forward(self, en_index):
        """Return the LSTM's final (h_n, c_n) state for a batch of id tensors."""
        embedded = self.embedding(en_index)   # (batch, seq) -> (batch, seq, embed)
        _, final_state = self.lstm(embedded)  # keep only the last hidden state
        return final_state



def translate(sentence):
    """Greedy-decode an English sentence into Chinese and print the result.

    Relies on module-level names set up in __main__: nlp, en_word2index,
    zh_word2index, zh_index2word, model and device.  Decoding is greedy
    (argmax per step) and capped at roughly 50 output words.
    """
    tokens = nlp.word_tokenize(sentence)

    # Guard against out-of-vocabulary words: a bare en_word2index[...] lookup
    # would raise KeyError and crash the interactive input loop.
    unknown = [w for w in tokens if w not in en_word2index]
    if unknown:
        print("unknown words, cannot translate:", unknown)
        return

    en_index = torch.tensor([[en_word2index[w] for w in tokens]], device=device)

    result = []
    with torch.no_grad():  # pure inference; no autograd graph needed
        encoder_hidden = model.encoder(en_index)
        decoder_input = torch.tensor([[zh_word2index['<BOS>']]], device=device)
        decoder_hidden = encoder_hidden

        while True:
            decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden)
            pre = model.liner(decoder_output)

            word_index = int(torch.argmax(pre, dim=-1))
            word = zh_index2word[word_index]

            # Stop at end-of-sentence or at a hard length cap.
            if word == '<EOS>' or len(result) > 50:
                break

            result.append(word)
            # Feed the predicted word back in as the next decoder input.
            decoder_input = torch.tensor([[word_index]], device=device)

    print("译文： ", "".join(result))






class Decoder(nn.Module):
    """Embed Chinese token ids and advance a single-layer LSTM decoder."""

    def __init__(self, decoder_embedding_num, decoder_hidden_num, decoder_corpus_num):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(decoder_corpus_num, decoder_embedding_num)
        self.lstm = nn.LSTM(decoder_embedding_num, decoder_hidden_num, batch_first=True)

    def forward(self, decoder_input, hidden):
        """Run decoding steps from the given (h, c) state.

        Returns the per-step LSTM outputs and the updated (h, c) state so the
        caller can continue decoding token by token.
        """
        emb = self.embedding(decoder_input)
        out, new_hidden = self.lstm(emb, hidden)
        return out, new_hidden



class Seq2Seq(nn.Module):
    """Training wrapper: encoder + decoder + vocab projection, returning the loss.

    Uses teacher forcing: the decoder consumes the gold Chinese sequence
    shifted right (keeps <BOS>, drops the last token) and the loss compares
    its predictions with the sequence shifted left (drops <BOS>).
    """

    def __init__(self, encoder_embedding_num, encoder_hidden_num, encoder_corpus_num, decoder_embedding_num, decoder_hidden_num, decoder_corpus_num):
        super(Seq2Seq, self).__init__()
        # corpus_num = vocabulary size (embedding rows); embedding_num = vector width.
        self.encoder = Encoder(encoder_embedding_num, encoder_hidden_num, encoder_corpus_num)
        self.decoder = Decoder(decoder_embedding_num, decoder_hidden_num, decoder_corpus_num)
        self.liner = nn.Linear(decoder_hidden_num, decoder_corpus_num)
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        # nn.Flatten(0, 1) merges the (batch, seq) dims so prediction rows line
        # up one-to-one with label entries for CrossEntropyLoss.  (The original
        # code leaked a pointless local via a chained "flatten =" assignment.)
        self.flatten_prediction = nn.Flatten(0, 1)
        self.flatten_label = nn.Flatten(0, 1)

    def forward(self, en_index, zh_index):
        """Compute the teacher-forced cross-entropy loss for one padded batch."""
        decoder_input = zh_index[:, :-1]  # <BOS> w1 ... wn   (drop final token)
        label = zh_index[:, 1:]           # w1 ... wn <EOS>   (drop leading <BOS>)

        label = self.flatten_label(label)  # -> (batch*seq,)

        encoder_hidden = self.encoder(en_index)
        decoder_output, _ = self.decoder(decoder_input, encoder_hidden)

        prediction = self.liner(decoder_output)           # (batch, seq, vocab)
        prediction = self.flatten_prediction(prediction)  # (batch*seq, vocab)
        # CrossEntropyLoss requires prediction rows == label rows, and every
        # label id must be strictly less than the vocab (class) dimension.
        loss = self.cross_entropy_loss(prediction, label)

        return loss



if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-tokenized parallel corpora and vocabulary mappings produced
    # by a separate preprocessing step (pickles expected in the working dir).
    with open('zh_split.pickle', 'rb') as f:
        zh_split = pickle.load(f)

    with open('en_split.pickle', 'rb') as f:
        en_split = pickle.load(f)

    with open('en_word2index.pickle', 'rb') as f:
        en_word2index = pickle.load(f)

    with open('en_index2word.pickle', 'rb') as f:
        en_index2word = pickle.load(f)

    with open('zh_word2index.pickle', 'rb') as f:
        zh_word2index = pickle.load(f)


    with open('zh_index2word.pickle', 'rb') as f:
        zh_index2word = pickle.load(f)



    zh_vocab_size = len(zh_word2index)  # zh_vocab = 88577
    en_vocab_size = len(en_word2index)  # en_vocab = 67077

    # Append the special tokens to both directions of each vocabulary.
    # Chinese gets <PAD>/<BOS>/<EOS> (decoder side); English only needs <PAD>.
    zh_word2index.update({"<PAD>" : zh_vocab_size, "<BOS>" : zh_vocab_size + 1, "<EOS>" : zh_vocab_size + 2})
    en_word2index.update({"<PAD>" : en_vocab_size})

    zh_index2word_vocab_size = len(zh_index2word)
    zh_index2word.update({zh_index2word_vocab_size : "<PAD>", zh_index2word_vocab_size + 1 : "<BOS>", zh_index2word_vocab_size + 2 : "<EOS>"})

    en_index2word_vocab_size = len(en_index2word)
    en_index2word.update({en_index2word_vocab_size : "<PAD>"})


    def batch_data_process(batch_datas):
        """Collate (en_ids, zh_ids) pairs into two padded index tensors.

        English rows are right-padded with <PAD> to the batch maximum; Chinese
        rows are wrapped in <BOS>/<EOS> and then right-padded, so all rows of a
        batch share one length.  Tensors are created directly on `device`.
        """
        en_seqs = [en for en, _ in batch_datas]
        zh_seqs = [zh for _, zh in batch_datas]

        # Longest sequence on each side of this minibatch.
        en_max = max(len(s) for s in en_seqs)
        zh_max = max(len(s) for s in zh_seqs)

        en_pad = en_word2index["<PAD>"]
        zh_pad = zh_word2index["<PAD>"]
        bos = zh_word2index["<BOS>"]
        eos = zh_word2index["<EOS>"]

        en_batch = [s + [en_pad] * (en_max - len(s)) for s in en_seqs]
        zh_batch = [[bos] + s + [eos] + [zh_pad] * (zh_max - len(s)) for s in zh_seqs]

        return (torch.tensor(en_batch, device=device),
                torch.tensor(zh_batch, device=device))

    encoder_embedding_num = 1000  # width of each English word embedding vector
    encoder_hidden_num = 512
    encoder_corpus_num = 67078  # English vocab size including <PAD>

    decoder_embedding_num = 1000
    decoder_hidden_num = 512
    decoder_corpus_num = 88580  # Chinese vocab size including <PAD>/<BOS>/<EOS>


    batch_size=32
    epochs=100
    lr=0.001


    # Train on a 1000-sample subset.  NOTE(review): randperm(2000)[:1000] draws
    # 1000 *random* indices from the first 2000 samples, not the first 1000.
    indices = torch.randperm(2000)[:1000]
    sampler = SubsetRandomSampler(indices)


    data = MyDataset(en_split,zh_split,en_word2index,zh_word2index)
    # shuffle must stay False when a sampler is supplied; collate_fn pads batches.
    dataloader = DataLoader(data,batch_size,shuffle=False,collate_fn=batch_data_process, sampler=sampler)
    # len(en_word2index) == 67078, len(zh_word2index) == 88580 — must match the corpus_num constants above.

    model = Seq2Seq(encoder_embedding_num, encoder_hidden_num,encoder_corpus_num,decoder_embedding_num, decoder_hidden_num,decoder_corpus_num)  # corpus_num = embedding-matrix rows (vocab size), embedding_num = columns
    model = model.to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for e in range(epochs):

        for en_index, zh_index in dataloader:
            loss = model(en_index, zh_index)
            loss.backward()
            opt.step()
            opt.zero_grad()
        # Only the loss of the last batch of the epoch is reported.
        print(f"loss : {loss:.3f}")

    # NOTE(review): pickling a whole nn.Module is fragile across torch versions
    # and devices; torch.save(model.state_dict(), ...) is the recommended form.
    with open('my_model.pkl', 'wb') as f:
        pickle.dump(model, f)

    while True:

        sentence = input("请输入一句英文：")

        translate(sentence)


# 2025-02-07 summary:
# Trained on 1000 samples; after 100 epochs the loss fell to 0.007, after which
# interactive translation was tested and the outputs were correct.